import warnings
warnings.filterwarnings("ignore")
import os
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from matplotlib import animation
import numpy as np
import xarray as xr
from IPython.display import HTML
from tool.dataset import NetCDFDataset
from tool.loss import RMSELoss
from tool.utils import Util, ImgUtil
# Paths to shallow-water model outputs: high resolution (20 km grid)
# and low resolution (100 km grid).
file_hr = './data/eta_eulerian_resolution_20.0km.nc'
file_lr = './data/eta_eulerian_resolution_100.0km.nc'
# loading outputs from shallow water model and transposing dimensions
ds_hr = xr.open_mfdataset(file_hr).transpose('time', 'y', 'x')
ds_lr = xr.open_mfdataset(file_lr).transpose('time', 'y', 'x')
# Give the unnamed DataArray a meaningful name ('eta' = surface elevation)
da_hr = ds_hr.rename({'__xarray_dataarray_variable__':'eta'}).eta
da_lr = ds_lr.rename({'__xarray_dataarray_variable__':'eta'}).eta
# interpolation - regrid the high-res field onto the low-res (y, x) grid
da_hr_interp = da_hr.interp(y=da_lr.y, x=da_lr.x, method='linear')
da_hr_interp_v = da_hr_interp.values
# filtering only time steps present in coarse resolution (every 5th step)
# NOTE(review): slice(0, size - 1, 5) excludes the final time step —
# confirm the off-by-one endpoint is intentional.
da_hr_interp_subset = da_hr_interp.isel(time=slice(0, da_hr_interp.time.size - 1, 5))
da_lr_v = da_lr.values
da_hr_interp_subset_v = da_hr_interp_subset.values
# Notebook-style bare expressions: display the arrays in Jupyter output
da_hr_interp_subset
da_lr
def load_predictions(data, hora,
                     folder='/home/guinelli/mestrado/stconvs2s/output/full-dataset/predictions/stconvs2s-r/'):
    """Load the saved input, target and prediction tensors for one run.

    Parameters
    ----------
    data : str
        Date stamp of the training run, e.g. '20210303'.
    hora : str
        Time stamp of the training run, e.g. '200922'.
    folder : str, optional
        Directory holding the serialized tensors. Defaults to the
        stconvs2s-r full-dataset predictions folder (previously
        hard-coded).

    Returns
    -------
    tuple of torch.Tensor
        (inputs, targets, predictions) exactly as saved by torch.save.
    """
    # NOTE: 'shalow' (sic) matches the actual file names on disk.
    prefix = os.path.join(folder, f'shalow_water_step20_4_{data}-{hora}')
    inputs = torch.load(prefix + '_input.pt')
    targets = torch.load(prefix + '_target.pt')
    predictions = torch.load(prefix + '_predictions.pt')
    return inputs, targets, predictions
# Load one run's saved tensors and convert them to numpy for plotting
inputs, targets, predictions = load_predictions(data='20210303', hora='200922')
np_inputs = inputs.cpu().numpy()
np_targets = targets.cpu().numpy()
np_predictions = predictions.cpu().numpy()
class RMSELoss(torch.nn.Module):
    """Root-mean-square error loss: sqrt(MSE(yhat, y) + eps).

    The small eps keeps the square root numerically stable (finite
    gradient) when the MSE is exactly zero.

    NOTE: this local definition shadows the RMSELoss imported from
    tool.loss at the top of the file; the local one is used below.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps  # stabilizer added under the sqrt (was hard-coded 1e-6)
        self.mse = torch.nn.MSELoss()

    def forward(self, yhat, y):
        """Return sqrt(MSE(yhat, y) + eps) as a 0-dim tensor."""
        return torch.sqrt(self.mse(yhat, y) + self.eps)
def get_loss(target, output):
    """Compute per-time-step RMSE and MAE between output and target.

    Both arguments are numpy arrays indexed as [batch, channel, time, y, x]
    (the losses are evaluated independently for each index along axis 2).

    Returns two lists of 0-dim torch tensors:
    (rmse_per_step, mae_per_step).
    """
    criterium = RMSELoss()
    rmse_per_step = []
    mae_per_step = []
    n_steps = target.shape[2]
    for step in range(n_steps):
        yhat = torch.from_numpy(output[:, :, step, :, :])
        y = torch.from_numpy(target[:, :, step, :, :])
        rmse_per_step.append(criterium(yhat, y))
        mae_per_step.append(F.l1_loss(yhat, y))
    return rmse_per_step, mae_per_step
# Baseline per-step error: low-res run vs. interpolated/subsampled
# high-res run. [None, None] adds the singleton batch and channel axes
# that get_loss expects.
rmse_loss, mae_loss = get_loss(da_hr_interp_subset_v[None, None], da_lr_v[None, None])
fim = 80
inicio = fim-1
c = RMSELoss()
# Overall network errors; [:,2:] presumably selects channel 2 onward
# (eta?) — TODO confirm channel layout.
print(RMSELoss()(predictions[:,2:], targets[:,2:]))
print(F.l1_loss(predictions[:,2:], targets[:,2:]))
rmse_loss_2, mae_loss_2 = get_loss(np_predictions[:,2:], np_targets[:,2:])
#rmse_loss_2, mae_loss_2 = get_loss(np_predictions[inicio:fim], np_targets[inicio:fim])
# Bar chart: per-step RMSE of the low-res baseline (every 2nd of the
# first 40 steps) vs. the network's 20 predicted steps.
fig, ax = plt.subplots(figsize=(10,3))
ax.bar(np.arange(1,21)-0.2, rmse_loss[:40:2], width=0.4, label="RMSE Low Resolution")
ax.bar(np.arange(1,21)+0.2, rmse_loss_2, width=0.4, label="RMSE NN")
ax.set_xticks(np.arange(1,21))
plt.legend(loc="upper left")
# --- Time-series comparisons: interpolated high-res vs. low-res eta ---
# Each figure plots the field at one grid point over the time window
# [shift, shift + length).
length = 100
shift = 0
f = shift + length
plt.figure(figsize=(10,3))
plt.plot(np.arange(shift, f), da_hr_interp_subset_v[shift:f,0,0], "-b", label="High Resolution")
plt.plot(np.arange(shift, f), da_lr_v[shift:f,0,0], "-g", label="Low Resolution")
plt.legend(loc="upper left")
plt.show()
# Same comparison, longer window, grid point (5, 5)
length = 300
shift = 0
f = shift + length
plt.figure(figsize=(10,3))
plt.plot(np.arange(shift, f), da_hr_interp_subset_v[shift:f,5,5], "-b", label="High Resolution")
plt.plot(np.arange(shift, f), da_lr_v[shift:f,5,5], "-g", label="Low Resolution")
plt.legend(loc="upper left")
plt.show()
# Later in the simulation (from step 5000), grid point (10, 10)
length = 100
shift = 5000
f = shift + length
plt.figure(figsize=(10,3))
plt.plot(np.arange(shift, f), da_hr_interp_subset_v[shift:f,10,10], "-b", label="High Resolution")
plt.plot(np.arange(shift, f), da_lr_v[shift:f,10,10], "-g", label="Low Resolution")
plt.legend(loc="upper left")
plt.show()
# From step 5000, longer window, grid point (5, 5)
length = 300
shift = 5000
f = shift + length
plt.figure(figsize=(10,3))
plt.plot(np.arange(shift, f), da_hr_interp_subset_v[shift:f,5,5], "-b", label="High Resolution")
plt.plot(np.arange(shift, f), da_lr_v[shift:f,5,5], "-g", label="Low Resolution")
plt.legend(loc="upper left")
plt.show()
# --- Network prediction vs. target at single grid points ---
# Index [0,2,:,i,j]: first sample, channel 2, all time steps, point (i, j)
length = np_targets.shape[2]
plt.figure(figsize=(10,3))
plt.plot(np.arange(1,length+1), np_targets[0,2,:,10,10], "-b", label="target")
plt.plot(np.arange(1,length+1), np_predictions[0,2,:,10,10], "-g", label="output")
plt.legend(loc="upper left")
#plt.ylim(0.04, 0.06)
plt.show()
# Same comparison, grid point (15, 15)
length = np_targets.shape[2]
plt.figure(figsize=(10,3))
plt.plot(np.arange(1,length+1), np_targets[0,2,:,15,15], "-b", label="target")
plt.plot(np.arange(1,length+1), np_predictions[0,2,:,15,15], "-g", label="output")
plt.legend(loc="upper left")
plt.show()
def load_checkpoint(filename, model_name, dataset_type=None, model=None, optimizer=None):
    """Restore model (and optimizer) weights from a saved checkpoint.

    Parameters
    ----------
    filename : str
        Checkpoint file. If relative, it is resolved inside the
        project's output/<dataset_type>/checkpoints/<model_name> folder.
    model_name : str
        Model folder name (lowercased when building the path).
    dataset_type : str, optional
        Dataset subfolder; required when filename is relative.
    model : torch.nn.Module, optional
        Model whose state_dict is loaded in place (when the file exists).
    optimizer : torch.optim.Optimizer, optional
        Optimizer whose state_dict is loaded in place.

    Returns
    -------
    torch.nn.Module
        The (possibly updated) model; returned unchanged when no
        checkpoint file is found.
    """
    if not os.path.isabs(filename):
        filename = os.path.join('/home/guinelli/mestrado/stconvs2s/output', dataset_type,
                                'checkpoints', model_name.lower(), filename)
    if os.path.isfile(filename):
        checkpoint = torch.load(filename)
        name = os.path.basename(filename)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        loss = checkpoint['loss']
        print(f'=> Loaded checkpoint {name} (best epoch: {epoch}, validation rmse: {loss:.4f})')
    else:
        # BUG FIX: the message previously printed a literal "(unknown)"
        # instead of the path that was actually searched.
        print(f'=> No checkpoint found at {filename}')
    return model
from model.stconvs2s import STConvS2S_R
# same parameters from training
x_shape = [1, 3, 5, 20, 20]  # presumably (batch, channel, time, y, x) — TODO confirm
num_layers = 3
hidden_dim = 32
kernel_size = 5
device = torch.device('cuda')
dropout_rate = 0
step = 20
# RMSprop hyperparameters (same values used during training)
opt_params = {
'lr': 0.001,
'alpha': 0.9,
'eps': 1e-6
}
model = STConvS2S_R(x_shape, num_layers, hidden_dim, kernel_size, device, dropout_rate, step)
model.to(device)
optimizer = torch.optim.RMSprop(model.parameters(), **opt_params)
# files with checkpoints from each training epoch
filename = 'shalow_water_step20_4_20210303-200922.pth.tar'
dataset_type = 'full-dataset'
model_name = 'stconvs2s-r'
# loading weights from best epoch
model = load_checkpoint(filename, model_name, dataset_type, model, optimizer)
ds = xr.open_mfdataset('/home/guinelli/mestrado/stconvs2s/data/dataset-shallow-water-3ch-20.0km-yseq20-ystep10.nc')
# Move the last axis to position 1 — presumably
# (sample, time, y, x, channel) -> (sample, channel, time, y, x); TODO confirm.
ds_x = torch.from_numpy(ds.x.values).float().permute(0, 4, 1, 2, 3)
ds_y = torch.from_numpy(ds.y.values).float().permute(0, 4, 1, 2, 3)
x_np = ds_x.cpu().numpy()
y_np = ds_y.cpu().numpy()
# getting last 20% of the dataset for test
start_i = int(np.ceil(ds.x.shape[0] * 0.8))
start_i
# generating input and expected output (target)
# NOTE(review): 'ds' is rebound here from the xarray dataset to a numpy array.
ds = x_np[start_i::10,:,1][:245].transpose(1,0,2,3)
x, y = ds[:,0:5][None], ds[:,5:][None]
# aprox. 12min * 240 (aprox. 2 days)
x.shape, y.shape
def make_predictions(model, initial_input, steps, device):
    """Autoregressively roll the model forward.

    Runs initial_input through the model, then repeatedly feeds the
    last 5 time steps of each output back in as the next input.

    Returns a list of `steps` output tensors (computed with gradients
    disabled).
    """
    model.eval()
    outputs = []
    with torch.no_grad():
        current = initial_input.to(device)
        for _ in range(steps):
            prediction = model(current)
            outputs.append(prediction)
            # Next window: the last 5 time steps of this prediction.
            current = prediction[:, :, -5:]
    return outputs
# Roll the trained model forward autoregressively: 12 calls, each
# producing a multi-step output, concatenated along the time axis.
pred = make_predictions(model=model, initial_input=torch.tensor(x, dtype=torch.float), steps=12, device=device)
# concatenating predictions along the time dimension
pred = torch.cat(pred, dim=2)
np_pred = pred.cpu().numpy()
np_pred.shape
# organizing data to plot a gif: target, output and their difference
diff = np_pred[0,2] - y[0,2]
data = [y[0,2],
np_pred[0,2],
diff]
titles = ['Target', 'Output', 'Diff']
%%capture
# creating gif animation
%matplotlib notebook
g = ImgUtil(data, titles, np.min(data), np.max(data))
# frames indicates how many steps should be printed
anim = g.get_animation_gif(frames=240, cmap='viridis')
HTML(anim.to_jshtml())
# Sanity check: both arrays should hold the same float dtype
type(np_pred[0,2,0,0,0]),type(np.float32(y)[0,0,0,0,0])
# Per-step errors: network rollout vs. target, and low-res baseline
# ([::2] subsamples the baseline series to align with the rollout steps)
rmse_nn, mae_nn = get_loss(np_pred[:,2:], y[:,2:])
# NOTE(review): this rebinds x (previously the model input) to a plot index.
x = np.arange(1,len(rmse_loss)+1)
rmse_lr, mae_lr = get_loss(da_hr_interp_subset_v[::2][None, None], da_lr_v[::2][None, None])
len(rmse_loss),len(x)
# Bar chart: per-step RMSE, network vs. low-res baseline (first 200 steps)
%matplotlib inline
width = 0.4
fig_bar = plt.figure(figsize=(15,3))
ax = fig_bar.add_axes([0,0,1,1])
ax.bar(np.arange(1,len(x[:200])+1)-width/2, rmse_nn[:200], width=0.4, label='RMSE NN')
ax.bar(np.arange(1,len(x[:200])+1)+width/2, rmse_lr[:200], width=0.4, label='RMSE Low Res')
ax.set_xticks(x[:200])
plt.ylabel('error')
plt.xlabel('prediction')
plt.legend(loc="upper center")
# Same chart zoomed in on the first 80 steps
%matplotlib inline
width = 0.4
fig_bar = plt.figure(figsize=(15,3))
ax = fig_bar.add_axes([0,0,1,1])
ax.bar(np.arange(1,len(x[:80])+1)-width/2, rmse_nn[:80], width=0.4, label='RMSE NN')
ax.bar(np.arange(1,len(x[:80])+1)+width/2, rmse_lr[:80], width=0.4, label='RMSE Low Res')
#ax.plot(np.arange(1,len(x[:200])+1)-width/2, rmse_nn[:200], width=0.4, label='RMSE NN')
#ax.plot(np.arange(1,len(x[:200])+1)+width/2, rmse_lr[:200], width=0.4, label='RMSE Low Res')
ax.set_xticks(x[:80])
plt.ylabel('error')
plt.xlabel('prediction')
plt.legend(loc="upper center")
# Network rollout vs. target over time at grid point (10, 10)
length = 240
plt.figure(figsize=(10,3))
plt.plot(np.arange(1,length+1), y[0,2,:length,10,10], "-b", label="target")
plt.plot(np.arange(1,length+1), np_pred[0,2,:length,10,10], "-g", label="output")
plt.legend(loc="upper left")
#plt.ylim(0.036, 0.154)
plt.show()
# Same comparison, grid point (15, 15), first 200 steps
length = 200
plt.figure(figsize=(10,3))
plt.plot(np.arange(1,length+1), y[0,2,:length,15,15], "-b", label="target")
plt.plot(np.arange(1,length+1), np_pred[0,2,:length,15,15], "-g", label="output")
plt.legend(loc="upper left")
#plt.ylim(0.06, 0.04)
plt.show()
# Shape sanity check before computing the spatial sums below
da_hr_interp_subset_v.shape, da_lr_v.shape, da_hr[1::5].shape
# Sum over both spatial axes per time step — presumably a proxy for
# total water volume / mass conservation of each series; TODO confirm.
sum_nn = np.squeeze(np_pred[:,2:]).sum(1).sum(1)
sum_lr = da_lr_v[::2].sum(1).sum(1)
sum_hr = da_hr_interp_subset_v[::2].sum(1).sum(1)
# /25: the 20 km grid has 5x5 cells per 100 km cell (per the file names),
# so the sum is rescaled to be comparable with the coarse grid.
sum_hr2 = da_hr.values[::10].sum(1).sum(1) / 25
length = len(sum_nn)
plt.figure(figsize=(10,3))
plt.plot(np.arange(1,length+1), sum_nn[:length], label="NN")
plt.plot(np.arange(1,length+1), sum_lr[:length], label="Low Res")
plt.plot(np.arange(1,length+1), sum_hr[:length], label="High Res")
plt.plot(np.arange(1,length+1), sum_hr2[:length], label="High Res 2")
plt.legend(loc="upper left")
#plt.ylim(0.06, 0.04)
plt.show()